package com.fary.net;

import com.fary.util.SynchronizedQueue;
import com.fary.util.SynchronizedStack;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.channels.*;
import java.util.Iterator;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

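/**
 * @author Fary
 * @version 1.0
 * @description: NIO-based endpoint. An acceptor thread takes connections from a blocking
 * server socket channel and hands each accepted channel to one of the poller threads,
 * which registers it with that poller's selector.
 * @date 2022/1/22 14:51
 */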
public class NioEndpoint extends AbstractEndpoint<NioChannel> {

    /**
     * Pseudo interest-op carried by a {@link PollerEvent} to request that the channel be
     * registered with its poller's selector (the event registers it for {@code OP_READ}).
     */
    public static final int OP_REGISTER = 0x100;

    /** Timeout passed to {@code Selector.select(long)} by the poller loop (milliseconds). */
    private long selectorTimeout = 1000;

    /** Counter incremented by {@link #getPoller0()} to rotate through {@link #pollers}. */
    private AtomicInteger pollerRotater = new AtomicInteger(0);

    /** Number of poller threads to start; defaults to at most 2, fewer on a single-CPU machine. */
    private int pollerThreadCount = Math.min(2, Runtime.getRuntime().availableProcessors());

    /** Selector pool opened in {@link #bind()} and exposed through {@link #getSelectorPool()}. */
    private NioSelectorPool selectorPool = new NioSelectorPool();

    /** Listening channel created in {@link #bind()}; the acceptor thread accepts from it. */
    private volatile ServerSocketChannel serverSock = null;

    /** Poller instances, created and started in {@link #startInternal()}. */
    private Poller[] pollers = null;

    /** Latch sized to the poller count in {@link #bind()}; each poller counts it down when its loop exits. */
    private volatile CountDownLatch stopLatch = null;

    /** Recycled {@link PollerEvent} objects, reused by {@code Poller.register}. */
    private SynchronizedStack<PollerEvent> eventCache;

    /** Recycled {@link NioChannel} objects, reused by {@link #setSocketOptions(SocketChannel)}. */
    private SynchronizedStack<NioChannel> nioChannels;

    /**
     * Returns the next poller in rotation.
     *
     * @return one of the pollers created by {@link #startInternal()}
     */
    public Poller getPoller0() {
        // Reduce first, then take the absolute value: the remainder is always in
        // (-length, length), so abs() cannot overflow. Math.abs(Integer.MIN_VALUE) is still
        // negative, which made the previous "abs then %" form produce a negative index once
        // the counter wrapped around.
        int idx = Math.abs(pollerRotater.incrementAndGet() % pollers.length);
        return pollers[idx];
    }

    /**
     * @return the selector pool owned by this endpoint
     */
    public NioSelectorPool getSelectorPool() {
        return selectorPool;
    }

    /**
     * Sets the number of poller threads. {@link #bind()} raises values {@code <= 0} to 1 and
     * {@link #startInternal()} reads the value when it creates the pollers.
     *
     * @param pollerThreadCount requested number of poller threads
     */
    public void setPollerThreadCount(int pollerThreadCount) {
        this.pollerThreadCount = pollerThreadCount;
    }

    /**
     * @return the configured number of poller threads
     */
    public int getPollerThreadCount() {
        return pollerThreadCount;
    }

    /**
     * @param stopLatch latch that each poller counts down when its loop exits
     */
    protected void setStopLatch(CountDownLatch stopLatch) {
        this.stopLatch = stopLatch;
    }

    /**
     * @return the latch counted down by pollers on exit, or {@code null} before {@link #bind()}
     */
    protected CountDownLatch getStopLatch() {
        return stopLatch;
    }

    /**
     * Prepares a freshly accepted channel and hands it to a poller.
     *
     * <p>The channel is switched to non-blocking mode, the configured socket properties are
     * applied, and it is wrapped in a {@link NioChannel} (a recycled one when the cache has
     * one, otherwise a new one with its own buffer handler) before being registered with the
     * next poller.
     *
     * @param socket the accepted channel
     * @return {@code true} if the channel was handed to a poller; {@code false} if anything
     *         failed, in which case the caller is expected to close the socket
     */
    protected boolean setSocketOptions(SocketChannel socket) {
        try {
            socket.configureBlocking(false);
            socketProperties.setProperties(socket.socket());

            final NioChannel recycled = nioChannels.pop();
            final NioChannel channel;
            if (recycled != null) {
                // Reuse a cached channel object for the new connection.
                recycled.setIOChannel(socket);
                recycled.reset();
                channel = recycled;
            } else {
                channel = new NioChannel(socket, new SocketBufferHandler(
                        socketProperties.getAppReadBufSize(),
                        socketProperties.getAppWriteBufSize(),
                        socketProperties.getDirectBuffer()));
            }

            getPoller0().register(channel);
            return true;
        } catch (Throwable t) {
            t.printStackTrace();
            return false;
        }
    }

    /**
     * @return a new {@link Acceptor} bound to this endpoint
     */
    @Override
    protected AbstractEndpoint.Acceptor createAcceptor() {
        return new Acceptor();
    }

    /**
     * Opens the listening channel, binds it to {@code getPort()} with a backlog of
     * {@code getAcceptCount()}, and normalizes the acceptor/poller thread counts.
     *
     * <p>The channel is kept in blocking mode; it is accepted from by the acceptor thread.
     * Failures while opening or binding are reported via {@code printStackTrace()} and do not
     * propagate. In that case the partially configured channel is closed and is not published
     * in {@code serverSock}, so no open, unbound channel is leaked.
     */
    @Override
    public void bind() {
        ServerSocketChannel channel = null;
        try {
            channel = ServerSocketChannel.open();
            socketProperties.setProperties(channel.socket());
            InetSocketAddress addr = new InetSocketAddress(getPort());
            channel.socket().bind(addr, getAcceptCount());
            channel.configureBlocking(true);
            // Publish only a fully configured, bound channel.
            serverSock = channel;
        } catch (Exception e) {
            e.printStackTrace();
            if (channel != null) {
                // Do not leak the half-initialized channel.
                try {
                    channel.close();
                } catch (IOException closeFailure) {
                    closeFailure.printStackTrace();
                }
            }
        }

        // Initialize the default thread counts for the acceptor and the pollers.
        // Multiple acceptor threads do not seem to work well.
        if (acceptorThreadCount == 0) {
            acceptorThreadCount = 1;
        }
        if (pollerThreadCount <= 0) {
            // minimum one poller thread
            pollerThreadCount = 1;
        }
        setStopLatch(new CountDownLatch(pollerThreadCount));

        selectorPool.open();
    }

    /**
     * Starts the endpoint if it is not already running: creates the object caches, makes sure
     * an executor exists, initializes the connection latch, starts one daemon thread per
     * poller and finally starts the acceptor threads. Calling it while already running is a
     * no-op.
     */
    public void startInternal() {
        if (running) {
            return;
        }
        running = true;
        paused = false;

        // Object caches used to recycle processors, poller events and channels.
        processorCache = new SynchronizedStack<>(SynchronizedStack.DEFAULT_SIZE, socketProperties.getProcessorCache());
        eventCache = new SynchronizedStack<>(SynchronizedStack.DEFAULT_SIZE, socketProperties.getEventCache());
        nioChannels = new SynchronizedStack<>(SynchronizedStack.DEFAULT_SIZE, socketProperties.getBufferPool());

        // Create the worker pool unless one was supplied.
        if (getExecutor() == null) {
            createExecutor();
        }

        initializeConnectionLatch();

        // One daemon thread per poller.
        pollers = new Poller[getPollerThreadCount()];
        for (int idx = 0; idx < pollers.length; idx++) {
            Poller poller = new Poller();
            pollers[idx] = poller;
            Thread thread = new Thread(poller, getName() + "-ClientPoller-" + idx);
            thread.setPriority(threadPriority);
            thread.setDaemon(true);
            thread.start();
        }

        startAcceptorThreads();
    }


    /**
     * Poller class.
     */
    public class Poller implements Runnable {

        /** Selector this poller waits on; opened in the constructor. */
        private Selector selector;
        /** Events queued by {@link #addEvent(PollerEvent)} and drained by {@link #events()}. */
        private final SynchronizedQueue<PollerEvent> events = new SynchronizedQueue<>();

        /** When observed as {@code true} by {@link #run()}, the loop closes the selector and exits. */
        private volatile boolean close = false;
        // Intended to optimize expiration handling; not read anywhere in this class yet.
        private long nextExpiration = 0;//optimize expiration handling

        /**
         * Hand-off counter between {@link #addEvent(PollerEvent)} and {@link #run()}: run() sets
         * it to -1 before selecting and back to 0 afterwards; addEvent() increments it.
         */
        private AtomicLong wakeupCounter = new AtomicLong(0);

        /** Result of the most recent select call made by {@link #run()}. */
        private volatile int keyCount = 0;

        /**
         * Creates a poller with its own selector.
         *
         * <p>NOTE(review): if {@code Selector.open()} fails the exception is only printed and
         * {@code selector} stays {@code null}, so later selector calls would throw
         * {@code NullPointerException} - confirm whether failing fast is preferable.
         */
        public Poller() {
            try {
                this.selector = Selector.open();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        /**
         * @return the selector owned by this poller (may be {@code null} if opening it failed)
         */
        public Selector getSelector() {
            return selector;
        }

        /**
         * Queues an event for the poller thread and wakes the selector when needed.
         *
         * @param event the event to run on the poller thread
         */
        private void addEvent(PollerEvent event) {
            // Enqueue before touching the counter so the event is visible once the poller wakes.
            events.offer(event);
            // run() sets the counter to -1 right before it selects and back to 0 afterwards, so
            // an increment that lands on 0 means this is the first event added since the poller
            // committed to that select: wake it up. Any other value needs no wakeup here.
            if (wakeupCounter.incrementAndGet() == 0) {
                selector.wakeup();
            }
        }

        /**
         * Runs the events that were queued when this call started.
         *
         * <p>At most as many events as the queue held on entry are processed, so events added
         * while draining are left for the next call. Each processed event is reset and, while
         * the endpoint is running and not paused, returned to the event cache. A failure in one
         * event is printed and does not stop the remaining events.
         *
         * @return {@code true} if at least one event was taken from the queue
         */
        public boolean events() {
            boolean handledAny = false;

            int remaining = events.size();
            while (remaining-- > 0) {
                final PollerEvent event = events.poll();
                if (event == null) {
                    // Queue drained earlier than the snapshot suggested.
                    break;
                }
                handledAny = true;
                try {
                    event.run();
                    event.reset();
                    if (running && !paused) {
                        eventCache.push(event);
                    }
                } catch (Throwable e) {
                    e.printStackTrace();
                }
            }

            return handledAny;
        }

        /**
         * Attaches a new socket wrapper to the channel and queues an {@code OP_REGISTER} event
         * so that the poller thread registers the channel with this poller's selector.
         *
         * <p>The wrapper gets its read and write timeouts from the configured socket timeout
         * and its keep-alive budget from the endpoint's max keep-alive requests. The event
         * object is taken from the event cache when one is available.
         *
         * @param socket the channel to register
         */
        public void register(final NioChannel socket) {
            socket.setPoller(this);

            final NioSocketWrapper wrapper = new NioSocketWrapper(socket, NioEndpoint.this);
            socket.setSocketWrapper(wrapper);
            wrapper.setPoller(this);
            wrapper.setReadTimeout(socketProperties.getSoTimeout());
            wrapper.setWriteTimeout(socketProperties.getSoTimeout());
            wrapper.setKeepAliveLeft(NioEndpoint.this.getMaxKeepAliveRequests());

            PollerEvent event = eventCache.pop();
            // OP_REGISTER ends up as an OP_READ registration, so record that on the wrapper.
            wrapper.interestOps(SelectionKey.OP_READ);
            if (event != null) {
                event.reset(socket, wrapper, OP_REGISTER);
            } else {
                event = new PollerEvent(socket, wrapper, OP_REGISTER);
            }
            addEvent(event);
        }

        /**
         * Tears down the connection behind a selection key.
         *
         * <p>The wrapper is detached from the key and released through the handler, the key is
         * cancelled, the {@link NioChannel} and the underlying channel are closed, and - when a
         * wrapper was attached - the connection count is decremented and the wrapper is marked
         * closed. Each cleanup step is isolated so that a failure in one (reported via
         * {@code printStackTrace()}) does not skip the following ones; previously a single
         * exception skipped the remaining closes and the connection count-down.
         *
         * @param key the key to cancel; {@code null} is a no-op
         * @return the wrapper that was attached to the key, or {@code null} if there was none
         */
        public NioSocketWrapper cancelledKey(SelectionKey key) {
            if (key == null) {
                return null; // nothing to do
            }
            NioSocketWrapper ka = null;
            try {
                // Detach first so that nobody else processes this wrapper through the key.
                ka = (NioSocketWrapper) key.attach(null);
                if (ka != null) {
                    try {
                        getHandler().release(ka);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                if (key.isValid()) {
                    key.cancel();
                }
                if (ka != null) {
                    try {
                        ka.getSocket().close(true);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
                try {
                    if (key.channel().isOpen()) {
                        key.channel().close();
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
                if (ka != null) {
                    countDownConnection();
                    ka.closed = true;
                }
            } catch (Throwable e) {
                e.printStackTrace();
            }
            return ka;
        }

        /**
         * Poller loop: drains queued events, selects on the selector and walks the selected
         * keys, until the {@code close} flag is observed; then drains once more, closes the
         * selector and counts down the endpoint's stop latch.
         *
         * <p>Failures inside the select phase are now reported via {@code printStackTrace()}
         * before the loop continues; they used to be swallowed silently.
         *
         * <p>NOTE(review): key processing and timeout handling are still disabled (see the
         * commented-out calls), so selected keys are only removed from the selected set, and
         * nothing in this file sets {@code close} to {@code true} yet.
         */
        @Override
        public void run() {
            while (true) {

                boolean hasEvents = false;

                try {
                    if (!close) {
                        hasEvents = events();
                        // Setting the counter to -1 tells addEvent() that the next event it
                        // queues must wake the selector. A positive previous value means events
                        // arrived in the meantime, so do not block.
                        if (wakeupCounter.getAndSet(-1) > 0) {
                            keyCount = selector.selectNow();
                        } else {
                            keyCount = selector.select(selectorTimeout);
                        }
                        wakeupCounter.set(0);
                    }
                    if (close) {
                        events();
//                        timeout(0, false);
                        try {
                            selector.close();
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                        break;
                    }
                } catch (Throwable x) {
                    // Report instead of hiding the failure, then retry on the next iteration.
                    x.printStackTrace();
                    continue;
                }
                // Either we timed out or we woke up, process events first
                if (keyCount == 0) {
                    hasEvents = (hasEvents | events());
                }

                Iterator<SelectionKey> iterator = keyCount > 0 ? selector.selectedKeys().iterator() : null;
                while (iterator != null && iterator.hasNext()) {
                    SelectionKey sk = iterator.next();
                    iterator.remove();
                    NioSocketWrapper socketWrapper = (NioSocketWrapper) sk.attachment();
                    if (socketWrapper != null) {
//                        processKey(sk, socketWrapper);
                    }
                }

//                timeout(keyCount, hasEvents);
            }

            getStopLatch().countDown();
        }
    }

    /**
     * Unit of work queued to a {@link Poller}: either the initial registration of a channel
     * with the poller's selector ({@link #OP_REGISTER}) or the addition of interest ops to an
     * already registered key. Instances are mutable so that they can be reset and reused.
     */
    public static class PollerEvent implements Runnable {

        private NioChannel socket;
        private int interestOps;
        private NioSocketWrapper socketWrapper;

        /**
         * @param ch     channel the event applies to
         * @param w      wrapper to attach to the key on registration
         * @param intOps {@link #OP_REGISTER} or the interest ops to add
         */
        public PollerEvent(NioChannel ch, NioSocketWrapper w, int intOps) {
            reset(ch, w, intOps);
        }

        /**
         * Re-targets this event.
         *
         * @param ch     channel the event applies to
         * @param w      wrapper to attach to the key on registration
         * @param intOps {@link #OP_REGISTER} or the interest ops to add
         */
        public void reset(NioChannel ch, NioSocketWrapper w, int intOps) {
            this.socket = ch;
            this.interestOps = intOps;
            this.socketWrapper = w;
        }

        /** Clears all references so that the event can be cached. */
        public void reset() {
            reset(null, null, 0);
        }

        /**
         * Applies the event. Registration failures are printed. For an interest-op update, a
         * missing key marks the wrapper closed and counts the connection down, and a key
         * without attachment or an already cancelled key is handed to
         * {@code Poller.cancelledKey}.
         */
        @Override
        public void run() {
            if (interestOps == OP_REGISTER) {
                try {
                    socket.getIOChannel().register(socket.getPoller().getSelector(), SelectionKey.OP_READ, socketWrapper);
                } catch (Exception e) {
                    e.printStackTrace();
                }
                return;
            }

            final SelectionKey key = socket.getIOChannel().keyFor(socket.getPoller().getSelector());
            try {
                if (key == null) {
                    // The channel is no longer registered with the selector.
                    socket.socketWrapper.getEndpoint().countDownConnection();
                    ((NioSocketWrapper) socket.socketWrapper).closed = true;
                    return;
                }

                final NioSocketWrapper attached = (NioSocketWrapper) key.attachment();
                if (attached == null) {
                    socket.getPoller().cancelledKey(key);
                    return;
                }

                // Add the requested ops to whatever the key is already interested in.
                final int ops = key.interestOps() | interestOps;
                attached.interestOps(ops);
                key.interestOps(ops);
            } catch (CancelledKeyException ckx) {
                try {
                    socket.getPoller().cancelledKey(key);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }

        @Override
        public String toString() {
            return "Poller event: socket [" + socket + "], socketWrapper [" + socketWrapper + "], interestOps [" + interestOps + "]";
        }
    }

    protected class Acceptor extends AbstractEndpoint.Acceptor {

        /**
         * Accept loop. While the endpoint is running it waits out pauses, reserves a connection
         * slot with {@code countUpOrAwaitConnection()}, accepts a connection and hands it to
         * {@code setSocketOptions}; sockets that cannot be handed over are closed.
         *
         * <p>Any exception thrown by {@code accept()} - not only {@code IOException} - now
         * gives the reserved connection slot back via {@code countDownConnection()}, matching
         * the other failure paths. Previously an unchecked failure (for example an unbound or
         * missing server socket) skipped the count-down.
         */
        @Override
        public void run() {
            long pauseStart = 0;

            while (running) {
                // Paused: spin for the first millisecond, then sleep 1 ms per round, and 10 ms
                // per round once the pause has lasted more than 10 ms.
                while (paused && running) {
                    if (state != AcceptorState.PAUSED) {
                        pauseStart = System.nanoTime();
                        state = AcceptorState.PAUSED;
                    }
                    if ((System.nanoTime() - pauseStart) > 1_000_000) {
                        try {
                            if ((System.nanoTime() - pauseStart) > 10_000_000) {
                                Thread.sleep(10);
                            } else {
                                Thread.sleep(1);
                            }
                        } catch (InterruptedException e) {
                            // Ignore: the loop condition re-checks paused/running.
                        }
                    }
                }

                // Stopped
                if (!running) {
                    break;
                }
                state = AcceptorState.RUNNING;

                try {
                    countUpOrAwaitConnection();

                    SocketChannel socket = null;
                    try {
                        socket = serverSock.accept();
                    } catch (Exception acceptFailure) {
                        // No socket was obtained: give the reserved slot back.
                        countDownConnection();
                        if (running) {
                            throw acceptFailure;
                        } else {
                            break;
                        }
                    }
                    // Configure the socket
                    if (running && !paused) {
                        if (!setSocketOptions(socket)) {
                            closeSocket(socket);
                        }
                    } else {
                        closeSocket(socket);
                    }
                } catch (Throwable e) {
                    e.printStackTrace();
                }
            }
            state = AcceptorState.ENDED;
        }


        /**
         * Gives the connection slot back and closes an accepted socket that will not be
         * processed. The socket and the channel are closed in separate steps so that an
         * {@code IOException} from the first close does not skip the second.
         *
         * @param socket the accepted channel to discard
         */
        private void closeSocket(SocketChannel socket) {
            countDownConnection();
            try {
                socket.socket().close();
            } catch (IOException e) {
                e.printStackTrace();
            }
            try {
                socket.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Socket wrapper for {@link NioChannel}: remembers the endpoint's selector pool, the owning
     * poller, the interest ops last requested for the key, and whether the connection has been
     * closed.
     */
    public static class NioSocketWrapper extends SocketWrapperBase<NioChannel> {

        private final NioSelectorPool pool;
        private Poller poller = null;
        private int interestOps = 0;
        /** Set once the connection behind this wrapper has been torn down. */
        private volatile boolean closed = false;

        /**
         * @param channel  channel being wrapped; its buffer handler becomes this wrapper's
         * @param endpoint owning endpoint, source of the selector pool
         */
        public NioSocketWrapper(NioChannel channel, NioEndpoint endpoint) {
            super(channel, endpoint);
            this.pool = endpoint.getSelectorPool();
            socketBufferHandler = channel.getBufHandler();
        }

        /**
         * @param poller the poller that owns this wrapper's channel
         */
        public void setPoller(Poller poller) {
            this.poller = poller;
        }

        /**
         * Records the interest ops for this wrapper.
         *
         * @param ops the new interest ops
         * @return {@code ops}, unchanged
         */
        public int interestOps(int ops) {
            return this.interestOps = ops;
        }

        @Override
        public boolean isClosed() {
            return this.closed;
        }
    }

    /**
     * Socket processing task for this endpoint. The processing step itself is not implemented
     * yet: {@link #doRun()} has an empty body.
     */
    protected class SocketProcessor extends SocketProcessorBase<NioChannel> {

        /**
         * @param socketWrapper wrapper of the socket to process
         * @param event         the socket event that triggered processing
         */
        public SocketProcessor(SocketWrapperBase<NioChannel> socketWrapper, SocketEvent event) {
            super(socketWrapper, event);
        }

        /** Placeholder: currently does nothing. */
        @Override
        protected void doRun() {

        }

    }
}
